---
title: Visualization
keywords: fastai
sidebar: home_sidebar
summary: "Functions designed to visualize how the model is performing on the dataset via saliency maps."
---
First, let's train a model for 3 epochs to have something reasonable for visualization.
import torch
from torch import nn
from torchvision import transforms

from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
def get_tta_transforms(resize_shape, normalize_transform, n=5):
    """Build a test-time-augmentation (TTA) transform pipeline.

    For each input image, the returned transform produces a stacked tensor
    containing ``n`` randomly augmented views (rotation, random resized
    crop, horizontal flip, color jitter) followed by one deterministic
    resized view, with ``normalize_transform`` applied to every view.

    Args:
        resize_shape: side length of the square output crops.
        normalize_transform: per-image normalization applied after stacking.
        n: number of random augmented views to generate per image.
    """
    augment = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
    ])
    deterministic = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor(),
    ])

    def _expand(image):
        # n random views plus the un-augmented resized view, stacked on dim 0.
        views = [augment(image) for _ in range(n)]
        views.append(deterministic(image))
        return torch.stack(views)

    def _normalize_all(images):
        # Normalize each view independently, then restack.
        return torch.stack([normalize_transform(view) for view in images])

    return transforms.Compose([
        transforms.Lambda(_expand),
        transforms.Lambda(_normalize_all),
    ])
def get_transforms(resize_shape, tta=False, tta_n=5):
    """Return ``(train_transforms, val_transforms)`` for the given crop size.

    Training applies random resized crops and horizontal flips; validation
    either performs a plain resize or, when ``tta`` is True, uses the
    test-time-augmentation pipeline from ``get_tta_transforms``.

    Args:
        resize_shape: side length of the square model input.
        tta: whether validation should use test-time augmentation.
        tta_n: number of random TTA views (only used when ``tta`` is True).
    """
    # ImageNet channel statistics, matching the pretrained backbone's training data.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    if tta:
        val_transforms = get_tta_transforms(resize_shape, normalize, n=tta_n)
    else:
        val_transforms = transforms.Compose([
            transforms.Resize((resize_shape, resize_shape)),
            transforms.ToTensor(),
            normalize,
        ])
    return train_transforms, val_transforms
# Build transforms for 224x224 inputs; enable test-time augmentation for validation.
train_transform, val_transform = get_transforms(224, tta=True)

# Split BreaKHis into train/val with the transforms built above.
# NOTE(review): `criterion` presumably controls how the split is stratified
# (by tumor type and magnification) — confirm against initialize_datasets.
ds_mapping = initialize_datasets(
    '/share/nikola/export/dt372/BreaKHis_v1/',
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)

# Two-class (benign/malignant) ResNet-18 starting from a pretrained backbone.
model = resnet18(pretrained=True, num_classes=2, create_log_and_save_dirs=False)
if torch.cuda.is_available():
    model = model.cuda()

# Training hyperparameters.
mixup = True
num_epochs = 3
base_lr = 1e-3
# Presumably per-section LR scaling factors for the pretrained body — confirm
# against get_param_lr_maps.
finetune_body_factor = [1e-5, 1e-2]

# NOTE(review): get_param_lr_maps, setup_optimizer_and_scheduler, train,
# validate, checkpoint_state and setup_logging_streams are not defined or
# imported in this file — presumably imported in a hidden notebook cell.
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))

# reduction='none' when mixup is on — presumably the train loop combines the
# per-sample losses for the two mixed targets itself; verify in train().
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)

# Train for num_epochs, validating (with TTA) and checkpointing each epoch.
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )
# Detach the file logging handlers installed by setup_logging_streams.
clear_logging_handlers()
Now that we have a trained model, let's use deterministic (non-random) transforms for inference and for the corresponding visualizations.
# Deterministic inference transform: plain resize plus ImageNet normalization,
# with no random augmentation.
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
inference_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])

# Load the entire dataset as a single 'all' split for visualization.
# NOTE(review): BreaKHisDataset is not imported in this file, and 'initalize'
# looks like a typo for 'initialize' — confirm the actual method name in
# breakhis_gradcam.data before changing it.
inference_ds = BreaKHisDataset.initalize(
    '/share/nikola/export/dt372/BreaKHis_v1/', label='tumor_class',
    criterion=['tumor_type', 'magnification'],
    split={'all': 1.0},
    split_transforms={'all': inference_transform}
)['all'].dataset
Here's an example of what one of our images looks like.
# Display the first sample, raw and then after the inference preprocessing.
# NOTE(review): show_image / get_preprocessed_image are not defined in this
# file — presumably imported in a hidden notebook cell.
show_image(inference_ds[0])
get_preprocessed_image(inference_ds[0], inference_transform)
This is the main function for visualization. It shows an activation map built from gradient-weighted activations of the last layer of the model (specifically, the activations of `layer4` for every ResNet). Note that by default, the activation map reflects how strongly the model believes the label is correct. By specifying `show_for_label=False` and `show_for_prediction=True`, one can instead see the activation heatmap explaining why the model might believe something other than the label is correct.
Below, an example is shown when the above model is visualized on a benign and malignant example.
# Grad-CAM heatmaps for two samples (presumably index 0 is benign and index
# 4000 malignant, per the surrounding prose — confirm), with and without the
# per-channel activation grid.
# NOTE(review): show_heatmap_and_original is not defined in this file —
# presumably imported in a hidden notebook cell.
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=False)
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=True)
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=False)
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=True)
# Heatmap for the ground-truth label...
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=True, show_activation_grid=True
)
# ...versus the heatmap for the model's own prediction.
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=False, show_for_prediction=True,
    show_activation_grid=True
)